import argparse
import copy
import logging
import os
import shutil
import time

import dgl
import dgl.nn.pytorch as dglnn
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dgl.dataloading import Sampler
from tqdm import tqdm

from data.load_dataset import load_ogb, load_partitioned_graphs, load_complete_graphs
from model.GIN import ClientServerGIN, evaluate as GIN_evaluate
from model.GraphSage import ClientServerGraphSage, evaluate as GraphSage_evaluate
from sampler import LocalNeighborSampler, GlobalNeighborSampler
from utility import print_yellow, print_green, print_red, print_blue
from utility import print_graph_stats

# Seed applied to the torch and numpy random number generators by the script entry point.
RANDOM_SEED = 1024

def construct_local_client_sampler(local_graph, complete_graph, fan_outs, args, sampling_device=th.device('cuda')):
    """
    Construct the local and global dataloaders for one client.

    Args:
        local_graph: The client's partition. Reads ``ndata['train_mask']`` and
            ``ndata['_ID']`` (IDs of the local nodes in ``complete_graph``).
        complete_graph: The complete graph. Its ``ndata['train_mask']`` is read
            only when ``args.baseline_llcg`` is set.
        fan_outs: Per-layer fan-outs handed to the neighbor samplers.
        args: Namespace providing ``batch_size``, ``baseline_llcg`` and
            ``num_layers``.
        sampling_device: Device passed to both dataloaders. UVA is enabled
            unless this is the CPU device.

    Returns:
        A tuple ``(local_mask, local_dataloader, global_dataloader)`` where
        ``local_mask`` is a boolean tensor over the nodes of ``complete_graph``
        that is True for the nodes owned by this client.
    """
    # as_tuple=True always yields a 1-D index tensor, whereas nonzero().squeeze()
    # collapses to a 0-d tensor when exactly one node is selected.
    train_nid_in_local_graph = local_graph.ndata['train_mask'].nonzero(as_tuple=True)[0]

    # UVA sampling is used for every device except the CPU
    use_uva = not (sampling_device == th.device('cpu'))

    # Get all the local nodes' IDs in the complete graph
    local_node_id = local_graph.ndata['_ID']

    # Get training node original IDs in the complete graph
    local_training_node_nid = local_node_id[train_nid_in_local_graph]

    # Build the local mask
    local_mask = th.zeros(complete_graph.number_of_nodes(), dtype=th.bool)
    local_mask[local_node_id] = True
    # Tensor.sum() reduces natively; the builtin sum() would iterate the mask
    # one element at a time in Python.
    assert int(local_mask.sum()) == len(local_node_id)

    # Options shared by every dataloader built below
    loader_kwargs = dict(
        device=sampling_device, num_workers=0, use_uva=use_uva,  # enable GPU sampling
        batch_size=args.batch_size, shuffle=True, drop_last=False,
    )

    # Build local data loader
    local_sampler = LocalNeighborSampler(fan_outs)
    local_dataloader = dgl.dataloading.DataLoader(
        local_graph, train_nid_in_local_graph, local_sampler, **loader_kwargs)

    # Build global data loader
    if args.baseline_llcg:
        # LLCG: full-neighborhood sampling over all training nodes of the complete graph
        global_training_node_nid = complete_graph.ndata['train_mask'].nonzero(as_tuple=True)[0]
        llcg_fanouts = [-1 for _ in range(args.num_layers)]
        global_local_mask = th.zeros(complete_graph.number_of_nodes(), dtype=th.bool)
        # Fill the mask before the IDs are moved to the sampling device, so the
        # mask is never indexed with a tensor living on another device.
        global_local_mask[global_training_node_nid] = True
        global_training_node_nid = global_training_node_nid.to(sampling_device)
        global_sampler = GlobalNeighborSampler(llcg_fanouts, global_local_mask)
        global_dataloader = dgl.dataloading.DataLoader(
            complete_graph, global_training_node_nid, global_sampler, **loader_kwargs)
    else:
        global_sampler = GlobalNeighborSampler(fan_outs, local_mask)
        global_dataloader = dgl.dataloading.DataLoader(
            complete_graph, local_training_node_nid, global_sampler, **loader_kwargs)
    return local_mask, local_dataloader, global_dataloader

def synchronize_model_parameters(models, global_sample_clients, args):
    """
    Synchronize the parameters of all client models in place.

    When ``global_sample_clients`` is empty or ``args.baseline_llcg`` is not
    set, every entry of the state dict is replaced by its average over all
    models. Otherwise (LLCG global step) the state of the single client listed
    in ``global_sample_clients`` is copied to every other model.

    Args:
        models: List of client models sharing the same state-dict keys.
        global_sample_clients: Sized, indexable collection with the IDs of the
            clients that did global sampling in this step (may be empty).
        args: Namespace providing ``baseline_llcg``.

    Raises:
        AssertionError: In the LLCG global step, if more than one client is
            listed in ``global_sample_clients``.
    """
    # Build each state dict once instead of once per (key, model) pair.
    # copy_ writes in place into the tensors returned by state_dict(), which is
    # how the models themselves get updated.
    state_dicts = [model.state_dict() for model in models]
    param_keys = list(state_dicts[0].keys())

    if (len(global_sample_clients) == 0) or (not args.baseline_llcg):
        n = len(models)
        with th.no_grad():
            for param_key in param_keys:
                # Averaging model parameters
                avg_param = sum([state[param_key] for state in state_dicts]) / n
                for state in state_dicts:
                    state[param_key].copy_(avg_param)
    else:   # For the baseline llcg algorithm & global step
        assert len(global_sample_clients) == 1, "LLCG only supports one global sample client per step"

        trained_client_id = global_sample_clients[0]
        trained_state = state_dicts[trained_client_id]
        # Replacing other clients' parameters with the trained client's parameters
        with th.no_grad():
            for client_id, state in enumerate(state_dicts):
                if client_id == trained_client_id:
                    continue
                for param_key in param_keys:
                    state[param_key].copy_(trained_state[param_key])



def create_client_logger(log_dir, client_id):
    """
    Create the file logger of one client.

    The logger is named ``Client_<client_id>``, logs at INFO level and writes
    bare messages to ``<log_dir>/client_<client_id>.log``, truncating the file.

    Args:
        log_dir: Existing directory that receives the log file.
        client_id: Client identifier used in the logger and file names.

    Returns:
        The configured ``logging.Logger``.
    """
    client_logger = logging.getLogger(f'Client_{client_id}')
    client_logger.setLevel(logging.INFO)  # Set to INFO to avoid logging debug-level messages

    # getLogger returns the same object for the same name, so drop the handlers
    # left by an earlier call; otherwise every message is emitted once per call
    # and the old file handles stay open.
    for stale_handler in list(client_logger.handlers):
        client_logger.removeHandler(stale_handler)
        stale_handler.close()

    file_handler = logging.FileHandler(f'{log_dir}/client_{client_id}.log', mode='w')
    formatter = logging.Formatter('%(message)s')  # Simplified format without time or debug level
    file_handler.setFormatter(formatter)
    client_logger.addHandler(file_handler)
    return client_logger

def create_valid_logger(log_dir):
    """
    Create the file logger used for validation results.

    The logger is named ``Validation``, logs at INFO level and writes bare
    messages to ``<log_dir>/validation.log``, truncating the file.

    Args:
        log_dir: Existing directory that receives the log file.

    Returns:
        The configured ``logging.Logger``.
    """
    validation_logger = logging.getLogger('Validation')
    validation_logger.setLevel(logging.INFO)

    # getLogger returns the same object for the same name, so drop the handlers
    # left by an earlier call; otherwise every message is emitted once per call
    # and the old file handles stay open.
    for stale_handler in list(validation_logger.handlers):
        validation_logger.removeHandler(stale_handler)
        stale_handler.close()

    validation_file_handler = logging.FileHandler(f'{log_dir}/validation.log', mode='w')
    validation_formatter = logging.Formatter('%(message)s')  # Simplified format
    validation_file_handler.setFormatter(validation_formatter)
    validation_logger.addHandler(validation_file_handler)
    return validation_logger

def _next_client_batch(iterloaders, dataloaders, client_id):
    """
    Return the next mini-batch of a client, restarting its dataloader when exhausted.

    The fresh iterator is stored back into ``iterloaders`` so that later calls
    continue from it.
    """
    try:
        return next(iterloaders[client_id])
    except StopIteration:
        iterloaders[client_id] = iter(dataloaders[client_id])
        return next(iterloaders[client_id])


def _dynamic_pns_step_times(args):
    """
    Return the hard-coded ``(local_step_time, global_step_time)`` pair used by
    the dynamic PNS baseline for the configured dataset and client count.

    Raises:
        NotImplementedError: If no timings are recorded for the configuration.
    """
    if args.dataset == 'reddit' and args.num_parts == 10:
        return 14.41, 335.55
    if args.dataset == 'ogbn-products' and args.num_parts == 20:
        return 11.76, 400.73
    raise NotImplementedError("Dynamic sampling interval update is not implemented for this dataset and number of clients")


def distributed_training_simulation(partitioned_graphs, complete_graph, device, fan_outs, args, model_prototype):
    """
    Simulate distributed training on multiple clients sharing the same GPU.

    Every client gets a deep copy of ``model_prototype`` with its own optimizer
    and loss. In each step the clients train on one mini-batch each (from the
    global sampler for the clients selected in a global sampling step, from the
    local sampler otherwise), the models are synchronized, and client 0's model
    is evaluated on the validation nodes of ``complete_graph``. Per-client and
    validation logs as well as periodic checkpoints are written below
    ``output/<partition_method>_<model_type>/``.

    Args:
        partitioned_graphs: Per-client partitions, indexed by client ID.
        complete_graph: The complete graph; ``ndata['val_mask']`` selects the
            validation nodes.
        device: Device holding the models; also used as the sampling device.
        fan_outs: Per-layer fan-outs for the training samplers.
        args: Parsed command-line namespace. Reads ``num_parts``, ``lr``,
            ``num_layers``, ``partition_method``, ``model_type``, ``dataset``,
            ``global_sample_interval``, ``num_global_sample_clients_per_step``,
            ``batch_size``, ``fan_out``, ``num_steps``, ``print_every``,
            ``save_every``, ``validate_sync`` and the ``baseline_*`` flags.
        model_prototype: Model that is deep-copied once per client.

    Raises:
        ValueError: If ``args.model_type`` is neither ``'graphsage'`` nor ``'gin'``.
        NotImplementedError: If ``args.baseline_dynamic_pns`` is set for a
            dataset / client count without recorded step times.
    """
    # Resolve configuration-dependent pieces up front so that an unsupported
    # setting fails before any training work is done.
    if args.model_type == 'graphsage':
        evaluate_fn = GraphSage_evaluate
    elif args.model_type == 'gin':
        evaluate_fn = GIN_evaluate
    else:
        raise ValueError(f"Unsupported model type: {args.model_type}")
    if args.baseline_dynamic_pns:
        local_step_time, global_step_time = _dynamic_pns_step_times(args)

    # Initialize models for each client
    print("Initializing models")
    models = [copy.deepcopy(model_prototype).to(device) for _ in range(args.num_parts)]
    optimizers = [optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-4) for model in models]
    loss_fcns = [nn.CrossEntropyLoss().to(device) for _ in range(args.num_parts)]  # Separate loss function for each client

    # Initialize the dataloaders for each client
    local_masks = []
    local_dataloaders = []
    global_dataloaders = []
    local_iterloaders = []
    global_iterloaders = []
    for client_id in tqdm(range(args.num_parts)):
        local_graph = partitioned_graphs[client_id]
        local_mask, local_dataloader, global_dataloader = construct_local_client_sampler(local_graph, complete_graph, fan_outs, args, sampling_device=device)

        local_masks.append(local_mask)
        local_dataloaders.append(local_dataloader)
        global_dataloaders.append(global_dataloader)

        # Create the iterloader
        local_iterloaders.append(iter(local_dataloader))
        global_iterloaders.append(iter(global_dataloader))

    # Initialize the validation
    valid_fanouts = [-1 for _ in range(args.num_layers)]
    # We are reusing the local sampler for validation
    valid_sampler = LocalNeighborSampler(valid_fanouts)
    # as_tuple=True keeps the index tensor 1-D even for a single validation node
    valid_nid = complete_graph.ndata['val_mask'].nonzero(as_tuple=True)[0]

    general_output_dir = f"output/{args.partition_method}_{args.model_type}/"

    # Log and checkpoint directories share the same run tag and baseline prefix
    run_tag = f"layer_{args.num_layers}_{args.dataset}_{args.num_parts}_{args.global_sample_interval}_{args.num_global_sample_clients_per_step}_{args.batch_size}_{args.lr}_{args.fan_out}"
    if args.baseline_llcg:
        baseline_prefix = "llcg_"
    elif args.baseline_pns:
        baseline_prefix = "pns_"
    elif args.baseline_dynamic_pns:
        baseline_prefix = "dynamic_pns_"
    else:
        baseline_prefix = ""

    # Init logging
    log_dir = f"{general_output_dir}{baseline_prefix}log/{run_tag}"
    os.makedirs(log_dir, exist_ok=True)
    # Create a list of loggers for all clients
    client_loggers = [create_client_logger(log_dir, i) for i in range(args.num_parts)]
    # Create a logger for validation
    validation_logger = create_valid_logger(log_dir)
    validation_logger.info(str(args))

    # Init Checkpoint
    saving_dir = f"{general_output_dir}{baseline_prefix}checkpoint/{run_tag}"
    # If exist, clean the directory. shutil.rmtree avoids building a shell
    # command from a path that contains command-line supplied values.
    if os.path.exists(saving_dir):
        shutil.rmtree(saving_dir)
    os.makedirs(saving_dir)

    init_sampling_interval = args.global_sample_interval
    sampling_interval = args.global_sample_interval

    if args.baseline_pns:
        pns_sampled_blocks = {} # Store the sampled blocks for each client during global sampling step

    if args.baseline_dynamic_pns:
        init_coefficient = None
        accumulated_runtime = 0

    for step in range(args.num_steps):
        is_global_sampling_step = step % sampling_interval == 0
        if is_global_sampling_step:
            # Sample clients for global sampling
            global_sample_clients = np.random.choice(args.num_parts, args.num_global_sample_clients_per_step, replace=False)
        else:
            global_sample_clients = np.array([])

        total_loss = 0.0
        num_trained_clients = 0
        for client_id in range(args.num_parts):
            step_start_time = time.time()

            # Check if the client is selected for global sampling
            do_global_sampling = client_id in global_sample_clients
            if do_global_sampling:
                # Get the data from the global sampler
                data = _next_client_batch(global_iterloaders, global_dataloaders, client_id)
                if args.baseline_pns:
                    pns_sampled_blocks[client_id] = data
            elif not args.baseline_pns:   # in the baseline pns algorithm, we do not need to sample the local data
                if args.baseline_llcg and global_sample_clients.size > 0:
                    assert global_sample_clients.size == 1
                    continue    # Skip the local training for the baseline llcg algorithm in global sampling step

                # Get the data from the local sampler
                data = _next_client_batch(local_iterloaders, local_dataloaders, client_id)
            if args.baseline_pns:
                # PNS always trains on the blocks stored at the last global sampling step
                data = pns_sampled_blocks[client_id]
            blocks, transform_indices = data
            # Moving to the device (in place, so stored PNS blocks stay on the device)
            for i, (local_block, remote_block) in enumerate(blocks):
                blocks[i] = (local_block.to(device), remote_block.to(device))

            # Extract features
            first_local_block, first_remote_block = blocks[0]
            local_feat_src = first_local_block.srcdata['features'].to(device)
            remote_feat_src = first_remote_block.srcdata['features'].to(device)

            # Extract labels
            last_local_block, last_remote_block = blocks[-1]
            assert last_remote_block.number_of_edges() == 0 # The last remote block should be empty
            labels = last_local_block.dstdata['labels'].to(device)

            # Compute loss and prediction
            optimizers[client_id].zero_grad()
            batch_pred = models[client_id](blocks, local_feat_src, remote_feat_src, transform_indices)
            loss = loss_fcns[client_id](batch_pred, labels)

            loss.backward()
            optimizers[client_id].step()

            # Read the loss back once; it is reused for averaging, printing and logging
            loss_value = loss.item()
            total_loss += loss_value
            num_trained_clients += 1
            step_time_ms = (time.time() - step_start_time) * 1000

            sampling_info = "G" if do_global_sampling else "L"

            if step % args.print_every == 0:
                print(
                    f"Step {step:04d} | "
                    f"Sampling = {sampling_info} |"
                    f"Client {client_id:04d} | "
                    f"Loss = {loss_value:.4f} | "
                    f"Step time = {step_time_ms:.2f}ms"
                )

            client_loggers[client_id].info(
                f"Step {step:04d} | Sampling = {sampling_info} | Client {client_id:04d} | "
                f"Loss = {loss_value:.4f} | Step time = {step_time_ms:.2f}ms"
            )

        # Average over the clients that actually trained in this step (LLCG
        # skips all but one client during a global sampling step).
        avg_loss = total_loss / max(num_trained_clients, 1)

        # Synchronize the model parameters across clients
        sync_start_time = time.time()
        synchronize_model_parameters(models, global_sample_clients, args)
        sync_time_ms = (time.time() - sync_start_time) * 1000
        if step % args.print_every == 0:
            print(f"Step {step:04d} | Synchronization time = {sync_time_ms:.2f}ms")
        validation_logger.info(f"Step {step:04d} | Synchronization time = {sync_time_ms:.2f}ms")

        # Validate if the model is synchronized
        if args.validate_sync:
            print_yellow("Validating the synchronization of model parameters")
            reference_state = models[0].state_dict()
            for key in reference_state.keys():
                for client_id in range(1, args.num_parts):  # Start from 1 since 0 is the reference
                    assert th.allclose(models[client_id].state_dict()[key], reference_state[key], atol=1e-6), f"Parameter {key} is not synchronized"
            print_yellow("Model parameters are synchronized")

        # Validate the model at every step
        valid_start_time = time.time()
        valid_acc = evaluate_fn(models[0], complete_graph, valid_sampler, device, valid_nid, args)
        if step % args.print_every == 0:
            print_yellow(f"Step {step:04d} | Validation Accuracy = {valid_acc:.4f}")

        step_status = "G" if is_global_sampling_step else "L"
        validation_logger.info(f"Step {step:04d} | Validation Accuracy = {valid_acc:.4f} | Validation time = {(time.time() - valid_start_time)*1000:.2f}ms | Sampling = {step_status} | Sampling interval = {sampling_interval}")

        if args.baseline_dynamic_pns:   # We should do a dynamic sampling interval update
            if init_coefficient is None:    # Initialize the coefficient from the first step's loss
                init_coefficient = ((0 + 1)/ avg_loss) ** (1/3)

            if is_global_sampling_step:
                accumulated_runtime += global_step_time
            else:
                accumulated_runtime += local_step_time

            # Scale the initial interval by the cube root of loss over accumulated runtime
            sampling_interval = int(init_sampling_interval * init_coefficient * (avg_loss / accumulated_runtime) ** (1/3))
            assert sampling_interval <= init_sampling_interval, "Sampling interval should not exceed the initial sampling interval"
            if sampling_interval < 1:
                sampling_interval = 1

        if step > 0 and step % args.save_every == 0:
            # Save the model (all clients are synchronized, so client 0 is saved)
            th.save(models[0].state_dict(), f"{saving_dir}/model_step_{step}.pth")
            print_yellow(f"Model saved at step {step}")


if __name__ == '__main__':
    # Script entry point: parse the arguments, load the complete and the
    # partitioned graphs, build the model prototype and run the simulation.

    # Set seed
    th.manual_seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)

    argparser = argparse.ArgumentParser("Train the model on a dataset")
    argparser.add_argument('--dataset', type=str, default='ogbn-products',
                           help='datasets: ogbn-arxiv, ogbn-products, reddit, citeseer, flickr')
    argparser.add_argument('--num_parts', type=int, default=10,
                           help='number of partitions')
    argparser.add_argument('--partition_method', type=str, default='metis',
                           help='partition method: metis')
    argparser.add_argument('--device', type=str, default='gpu',
                            help='training device: gpu, cpu')
    argparser.add_argument('--fan_out', type=str, default='15,10',
                            help='fan out for neighbor sampling, from 1-hop to n-hop neighbors, separated by comma')
    argparser.add_argument('--num_steps', type=int, default=4001)
    argparser.add_argument('--valid_batch_size', type=int, default=2000)
    argparser.add_argument('--save_every', type=int, default=5000, help='save model every n steps')
    argparser.add_argument('--print_every', type=int, default=50, help='print every n steps')

    # Model parameters: no need to input via arguments
    argparser.add_argument('--num_layers', type=int, default=2)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num_hidden', type=int, default=256)

    # For debugging
    argparser.add_argument('--validate_sync', action='store_true',
                            help='whether to validate the synchronization of model parameters')

    # Fed Training Setting; Do Grid search
    argparser.add_argument('--global_sample_interval', type=int, default=1,
                            help='global sample interval')
    argparser.add_argument('--num_global_sample_clients_per_step', type=int, default=10,
                            help='number of global sample clients per step')
    argparser.add_argument('--lr', type=float, default=0.001)
    argparser.add_argument('--batch_size', type=int, default=256)
    argparser.add_argument('--baseline_llcg', action='store_true',
                            help='whether to run the baseline llcg algorithm')
    argparser.add_argument('--baseline_pns', action='store_true',
                            help='whether to run the baseline periodic neighbour sampling algorithm')
    argparser.add_argument('--baseline_dynamic_pns', action='store_true',
                            help='whether to run the baseline *dynamic* periodic neighbour sampling algorithm')

    # Restricting the choices rejects an unknown model type at parse time
    # instead of failing later, after the graphs have been loaded.
    argparser.add_argument('--model_type', type=str, default='graphsage',
                            choices=['graphsage', 'gin'],
                            help='model type: graphsage, gin')

    # Parse the arguments
    args = argparser.parse_args()

    # Validate argument combinations explicitly; assert statements would be
    # stripped when Python runs with -O.
    if args.baseline_llcg and args.num_global_sample_clients_per_step != 1:
        argparser.error("LLCG only supports one global sample client per step")
    if args.baseline_pns and args.num_global_sample_clients_per_step != args.num_parts:
        argparser.error("PNS requires all clients to be global sample clients")
    if args.num_parts < args.num_global_sample_clients_per_step:
        argparser.error("--num_global_sample_clients_per_step must not exceed --num_parts")

    # Print the arguments
    print_red(args)

    fan_outs = [int(fanout) for fanout in args.fan_out.split(",")]
    if len(fan_outs) != args.num_layers:
        argparser.error("--fan_out must list exactly --num_layers values")

    # Setting up env
    if args.device == 'gpu':
        if not th.cuda.is_available():
            raise RuntimeError("No GPU found")
        print_red("Using GPU")
        device = th.device('cuda')
    else:
        print_red("Using CPU")
        device = th.device('cpu')

    # Load complete graph
    complete_graph = load_complete_graphs(args.dataset)

    # Get the feature dimension
    in_feats = complete_graph.ndata['features'].shape[1]
    print("Feature dimension:", in_feats)

    # Get the number of classes: distinct label values, ignoring NaN entries
    all_labels = complete_graph.ndata['labels']
    n_classes = len(th.unique(all_labels[th.logical_not(th.isnan(all_labels))]))
    print("Number of classes:", n_classes)

    # Load partitioned graphs
    output_dir = "partitioned_dataset"
    saving_graph_name = f"{args.dataset}_{args.partition_method}_{args.num_parts}"
    path_to_partitioned_dataset = f"{output_dir}/{saving_graph_name}.bin"

    partitioned_graphs = load_partitioned_graphs(path_to_partitioned_dataset)

    # Build model prototype
    if args.model_type == 'graphsage':
        model_prototype = ClientServerGraphSage(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
    elif args.model_type == 'gin':
        model_prototype = ClientServerGIN(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout, aggregator_type='mean')
    else:
        raise ValueError(f"Unsupported model type: {args.model_type}")

    # Initialize model parameters
    def weights_init(m):
        """Xavier-initialize the weight and zero the bias of every linear layer."""
        if isinstance(m, nn.Linear):
            th.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                th.nn.init.zeros_(m.bias)

    model_prototype.apply(weights_init)

    # Simulate distributed training over clients
    distributed_training_simulation(partitioned_graphs, complete_graph, device, fan_outs, args, model_prototype)

    
